Addition is All You Need For Energy-Efficient Language Models¶
%load_ext nvcc4jupyter
Source files will be saved in "/tmp/tmpbm17i6ti".
General Reasoning¶
LLMs, like most deep-learning systems, rely heavily on floating-point multiplication. By default this is done in 32-bit floats, though it has become popular to reduce the precision to 8-bit floats during the forward pass. Floating-point multiplication uses much more energy than integer addition, so it would be good if we could use addition instead. This paper presents a way to approximate 32-bit floating-point multiplication using only integer addition. It's slightly less precise than 32-bit multiplication, but far more precise than 8-bit floats, so it should be usable (and more energy efficient).
Method¶
Current floating point multiplication works something like this:
$\text{Mul}(x, y) = (1 + x_m + y_m + x_m y_m) \cdot 2^{x_e+y_e}$
Where $x_e$ is the exponent part of the floating point number $x$, and $x_m$ is the fraction part. Same goes for $y_e$ and $y_m$.
Then we XOR the two sign bits to determine the sign of the output.
The slowest part of this is actually the $x_m y_m$ term, since multiplying the two mantissas takes $O(m^2)$ bit operations (where $m$ is the number of mantissa bits).
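As a quick worked example (mine, not the paper's): take $x = 3 = (1 + 0.5) \cdot 2^1$ and $y = 5 = (1 + 0.25) \cdot 2^2$. Then
$\text{Mul}(3, 5) = (1 + 0.5 + 0.25 + 0.5 \cdot 0.25) \cdot 2^{1+2} = 1.875 \cdot 8 = 15.$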
Instead of doing it that way, we can replace the $x_m y_m$ term with a small constant correction $2^{-l(m)}$:
$\text{L-Mul}(x, y) = (1 + x_m + y_m + 2^{-l(m)}) \cdot 2^{x_e+y_e}$
where
$l(m) = \begin{cases}
m & \text{if } m \le 3,\\
3 & \text{if } m=4,\\
4 & \text{if } m > 4.
\end{cases}$
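Continuing the worked example above (again mine, not the paper's): with a full fp32 mantissa we have $m = 23 > 4$, so $l(m) = 4$ and the correction is $2^{-4} = 0.0625$, giving
$\text{L-Mul}(3, 5) = (1 + 0.5 + 0.25 + 0.0625) \cdot 2^{1+2} = 1.8125 \cdot 8 = 14.5,$
about 3% off the exact 15. Averaged over many operands the error is much smaller, around 1.35% in the measurement further down.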
In practice this would actually be really simple to implement. Treating $x$ and $y$ as bit arrays representing floating-point numbers (sign bit first):
$\text{L-Mul}(x,y)[0] = x[0] \oplus y[0]$
$\text{L-Mul}(x,y)[1:] = x[1:] + y[1:] - \text{offset}$
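The offset here is just the fp32 exponent bias combined with the $2^{-l(m)}$ term: with 23 mantissa bits, a bias of 127, and $l(m) = 4$, it works out to $127 \cdot 2^{23} - 2^{19}$ = 0x3F780000, which is the constant that shows up in the assembly below. Here's a minimal host-side C sketch of this bit-level view (my own illustration, not code from the paper; lmul_sketch is a made-up name, and like the emulation below it ignores zeros, infinities, NaNs, and exponent overflow):
#include <stdint.h>
#include <stdio.h>
#include <string.h>

// Hypothetical sketch of the bit-level formulation above, assuming fp32
// inputs and l(m) = 4, so offset = (127 << 23) - (1 << (23 - 4)) = 0x3F780000.
// No handling of zeros, infinities, NaNs, or exponent overflow.
static float lmul_sketch(float x, float y) {
    uint32_t xb, yb, zb;
    memcpy(&xb, &x, sizeof xb);
    memcpy(&yb, &y, sizeof yb);
    // XOR of the sign bits gives the sign of the result
    uint32_t sign = (xb ^ yb) & 0x80000000u;
    // Add the exponent+mantissa fields as integers, then subtract the offset
    uint32_t body = (xb & 0x7FFFFFFFu) + (yb & 0x7FFFFFFFu) - 0x3F780000u;
    // Mask off any carry into the sign position and attach the sign
    zb = sign | (body & 0x7FFFFFFFu);
    float z;
    memcpy(&z, &zb, sizeof z);
    return z;
}

int main(void) {
    // Matches the worked example: prints 14.5 for L-Mul vs the exact 15
    printf("L-Mul: %f  exact: %f\n", lmul_sketch(3.0f, 5.0f), 3.0f * 5.0f);
    return 0;
}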
Implementation¶
Now let's try to come up with some PTX assembly code we can use to emulate this.
// Create some intermediate registers
.reg .u32 r0;
.reg .u32 r1;
.reg .u32 r2;
.reg .u32 s1;
.reg .u32 s2;
// Move our two floating-point numbers into
// integer registers
mov.b32 r1, $1;
mov.b32 r2, $2;
// AND the first to get the sign bit
// and the non-sign bits respectively
and.b32 s1, r1, 0x80000000;
and.b32 r1, r1, 0x7FFFFFFF;
// Do the same for the second register
and.b32 s2, r2, 0x80000000;
and.b32 r2, r2, 0x7FFFFFFF;
// Now we perform integer addition,
// storing the output in r0
add.u32 r0, r1, r2;
// Subtract the offset (determined
// experimentally)
sub.u32 r0, r0, 0x3F780000;
// And if it overflowed into the sign bit,
// cut it off again
and.b32 r0, r0, 0x7FFFFFFF;
// Then we'll determine the sign with an XOR
// (the normal way)
xor.b32 s1, s1, s2;
// And tack it back onto the final answer
add.u32 r0, r0, s1;
// Then just move the result to the output register
mov.b32 $0, r0;
It's worth noting that the paper provides this PTX assembly only to emulate the process; ideally it would be implemented in real hardware rather than in assembly. I'm going to time it anyhow, just to see what it's like (maybe it's somewhat better), but the timing shouldn't be treated as the be-all and end-all of the result. What does matter here is the accuracy, as that should match what a hardware implementation would give.
%%cuda
#include <stdio.h>
#include <stdlib.h>
#include <math.h>

__device__ float lmul(float x, float y) {
    float z = 0;
    asm("{"
        ".reg .u32 r0;"
        ".reg .u32 r1;"
        ".reg .u32 r2;"
        ".reg .u32 s1;"
        ".reg .u32 s2;"
        " mov.b32 r1, %1;"
        " mov.b32 r2, %2;"
        " and.b32 s1, r1, 0x80000000;"
        " and.b32 r1, r1, 0x7FFFFFFF;"
        " and.b32 s2, r2, 0x80000000;"
        " and.b32 r2, r2, 0x7FFFFFFF;"
        " add.u32 r0, r1, r2;"
        " sub.u32 r0, r0, 0x3F780000;"
        " and.b32 r0, r0, 0x7FFFFFFF;"
        " xor.b32 s1, s1, s2;"
        " add.u32 r0, r0, s1;"
        " mov.b32 %0, r0;"
        "}"
        : "=f"(z) : "f"(x), "f"(y));
    return z;
}

__device__ float rmul(float x, float y) {
    return x*y;
}

__global__ void operate_lmul(float* x1, int size_x) {
    for (int i = 0; i < size_x; i += 2) {
        x1[i] = lmul(x1[i], x1[i+1]);
    }
}

__global__ void operate_rmul(float* x2, int size_x) {
    for (int i = 0; i < size_x; i += 2) {
        x2[i] = rmul(x2[i], x2[i+1]);
    }
}

int main(void) {
    // Create a few arrays
    unsigned long int N = 1<<20;
    float *device_x1, *device_x2, *host_x, *host_y;
    cudaMalloc(&device_x1, N*sizeof(float));
    cudaMalloc(&device_x2, N*sizeof(float));
    host_x = (float*)malloc(N*sizeof(float));
    host_y = (float*)malloc(N*sizeof(float));

    // Create some CUDA events so we can keep track of time
    cudaEvent_t start, stop;
    cudaEventCreate(&start);
    cudaEventCreate(&stop);

    // Initialize the host x as all random numbers
    for (int i = 0; i < N; i++) {
        // This will give relatively low values
        host_x[i] = (float)random() / (float)random();
    }

    // Then copy it over to the device lists
    cudaMemcpy(device_x1, host_x, N*sizeof(float), cudaMemcpyHostToDevice);
    cudaMemcpy(device_x2, host_x, N*sizeof(float), cudaMemcpyHostToDevice);

    // Run our cool function on them (a single thread, so the loop is serial)
    cudaEventRecord(start);
    operate_lmul<<<1, 1>>>(device_x1, N);
    cudaEventRecord(stop);
    // This function stops the CPU execution until the "stop" event is recorded
    cudaEventSynchronize(stop);
    // Then we'll compute the elapsed time using another CUDA function
    float milliseconds = 0;
    cudaEventElapsedTime(&milliseconds, start, stop);
    printf("Lmul operation completed in %fms\n", milliseconds);

    // Run the regular multiplication the same way
    cudaEventRecord(start);
    operate_rmul<<<1, 1>>>(device_x2, N);
    cudaEventRecord(stop);
    // This function stops the CPU execution until the "stop" event is recorded
    cudaEventSynchronize(stop);
    // Then we'll compute the elapsed time using another CUDA function
    milliseconds = 0;
    cudaEventElapsedTime(&milliseconds, start, stop);
    printf("Rmul operation completed in %fms\n", milliseconds);

    // Then copy them back to the host
    cudaMemcpy(host_x, device_x1, N*sizeof(float), cudaMemcpyDeviceToHost);
    cudaMemcpy(host_y, device_x2, N*sizeof(float), cudaMemcpyDeviceToHost);

    // And now we'll compute the average error of L-Mul vs regular mul
    float error = 0;
    float accur = 0;
    for (int i = 0; i < N; i++) {
        error += fabsf(host_x[i] - host_y[i]);
        accur += fabsf(host_x[i] - host_y[i])/host_y[i];
    }
    error = error/N;
    accur = accur/N;
    accur *= 100;
    printf("Average error was %f\nOn average error was %.3f%% off from the real answer.\n", error, accur);

    // We can now deallocate everything
    cudaFree(device_x1);
    cudaFree(device_x2);
    free(host_x);
    free(host_y);
}
Lmul operation completed in 109.921791ms
Rmul operation completed in 40.054783ms
Average error was 0.392336
On average error was 1.350% off from the real answer.
Results¶
Well, it's a lot slower (to be expected, given how we've emulated it), but it's surprisingly accurate (at least for floats in this range). Overall I don't think there's any benefit to be had from this until hardware that actually implements it is produced, but the idea is very interesting, and I hope we see it show up in a GPU sometime soon.